import random
import gym
from collections import OrderedDict
from gym import error, spaces, utils
from gym.utils import seeding
import hashlib
import os

import numpy as np
from gym_malware.envs.utils import interface, pefeatures

from gym_malware.envs.controls import manipulate2 as manipulate
# Integer action index (as produced by the Discrete action space) -> name of
# the corresponding manipulation in manipulate.ACTION_TABLE.
ACTION_LOOKUP = dict(enumerate(manipulate.ACTION_TABLE.keys()))

# Point this at the scoring function of the machine-learning model you want to
# attack. It must have the form
#     def score_function(bytez):
#         ...  # return a score in [0.0, 1.0], where benign=0.0 and malware=1.0
score_function = interface.get_score_local
# Point this at the model's malicious/benign decision threshold: samples that
# score below it are treated as benign.
malicious_threshold = interface.local_model_threshold # if specified in params.json, then this is  interface.__private_data['threshold']

class MalwareScoreEnv(gym.Env):
    """Gym environment in which an agent mutates a PE file to evade a scorer.

    Each episode picks one sample (identified by sha256) from ``sha256list``
    whose score is at or above ``malicious_threshold``. Every step applies one
    manipulation from ``ACTION_LOOKUP`` to the file's bytes and re-scores it
    with ``score_function``. An episode ends when:

    * the score drops below ``malicious_threshold`` (evasion, reward 10.0);
    * ``maxturns`` steps have been taken (reward = score improvement); or
    * classification fails (reward 0.0).

    Observations are the values returned by
    ``pefeatures.PEFeatureExtractor.extract``, passed through ``np.asarray``.
    NOTE: for backward compatibility the latest observation is also stored in
    ``self.observation_space`` (this is not a ``gym.spaces`` object).

    ``self.history`` maps each played sha256 to a dict with the keys
    ``'actions'`` (the list of applied action names), ``'evaded'`` (a bool)
    and, after a successful evasion, ``'evaded_sha256'``.
    """
    metadata = {'render.modes': ['human']}

    def __init__(self, sha256list, random_sample=True, maxturns=3, output_path='evaded/score/', cache=False):
        """Create the environment and load the first episode's sample.

        Args:
            sha256list: sequence of sha256 identifiers passed to
                ``interface.fetch_file``.
            random_sample: if True, choose each episode's sample with
                ``random.choice``; otherwise cycle through ``sha256list``
                in order.
            maxturns: maximum number of actions per episode.
            output_path: directory for evasive samples. It is joined onto the
                directory three levels above this file, so an absolute path
                is used as given (``os.path.join`` semantics).
            cache: if True, fetch every sample once up front and keep the
                bytes in ``self.samples``.
        """
        self.cache = cache
        self.available_sha256 = sha256list
        self.action_space = spaces.Discrete(len(ACTION_LOOKUP))
        self.maxturns = maxturns
        self.feature_extractor = pefeatures.PEFeatureExtractor()
        self.random_sample = random_sample
        self.sample_iteration_index = 0
        self.output_path = os.path.join(
            os.path.dirname(
                os.path.dirname(
                    os.path.dirname(
                        os.path.abspath(__file__)))), output_path)
        # Create the directory _step actually writes to. Previously the
        # relative ``output_path`` argument was created instead, which only
        # matched self.output_path by accident of the working directory.
        os.makedirs(self.output_path, exist_ok=True)

        self.history = OrderedDict()

        # sha256 -> file bytes; only populated when cache=True.
        self.samples = {}
        if self.cache:
            self._load_cache()

        self._reset()  # sets self.original_score, self.bytez and self.observation_space

    def _load_cache(self):
        """Fetch every available sample into ``self.samples``, skipping failures."""
        for sha256 in self.available_sha256:
            try:
                self.samples[sha256] = interface.fetch_file(sha256)
            except interface.FileRetrievalFailure:
                # This one can't be retrieved from storage; _reset skips
                # sha256s that are missing from the cache.
                print("failed fetching file")

    def _step(self, action_index):
        """Apply one action, re-score the file and compute the reward.

        Returns:
            ``(observation, reward, episode_over, info)``; ``info`` is
            always an empty dict.
        """
        self.turns += 1
        self._take_action(action_index)

        # get reward
        try:
            self.score = score_function(self.bytez)
        except interface.ClassificationFailure:
            print("Failed to classify file")
            # ``reward`` used to be left unbound here, so the print/return
            # below raised UnboundLocalError. End the episode with no reward
            # and keep the previous observation.
            reward = 0.0
            episode_over = True
        else:
            self.observation_space = self.feature_extractor.extract(self.bytez)
            if self.score < malicious_threshold:
                # we win!
                reward = 10.0
                episode_over = True
                self.history[self.sha256]['evaded'] = True

                # Store the evasive variant in the output directory, named by
                # its own sha256.
                evaded_sha256 = hashlib.sha256(self.bytez).hexdigest()
                self.history[self.sha256]['evaded_sha256'] = evaded_sha256
                with open(os.path.join(self.output_path, evaded_sha256), 'wb') as outfile:
                    outfile.write(self.bytez)
            else:
                # Intermediate reward (and final reward when out of turns) is
                # the score improvement over the unmodified sample.
                reward = self.original_score - self.score
                episode_over = self.turns >= self.maxturns

        if episode_over:
            print("episode is over: reward = {}!".format(reward))

        # Same observation type as _reset returns.
        return np.asarray(self.observation_space), reward, episode_over, {}

    def _take_action(self, action_index):
        """Apply ``ACTION_LOOKUP[action_index]`` to ``self.bytez`` and log it.

        Raises:
            ValueError: if ``action_index`` is outside the action space.
        """
        # An explicit check (not ``assert``) so it also runs under ``python -O``
        # and rejects negative indices.
        if not 0 <= action_index < len(ACTION_LOOKUP):
            raise ValueError("invalid action index: {}".format(action_index))
        action = ACTION_LOOKUP[action_index]
        print(action)
        self.history[self.sha256]['actions'].append(action)
        self.bytez = bytes(
            manipulate.modify_without_breaking(self.bytez, [action]))

    def _next_sha256(self):
        """Return the next candidate sha256: random, or round-robin over the list."""
        if self.random_sample:
            return random.choice(self.available_sha256)
        # Sequential mode: iterate through the list in order, wrapping around.
        sha256 = self.available_sha256[
            self.sample_iteration_index % len(self.available_sha256)]
        self.sample_iteration_index += 1
        return sha256

    def _load_sample(self):
        """Load and score ``self.sha256``; return True if it is usable.

        Sets ``self.bytez`` and ``self.original_score`` as side effects. A
        sample is unusable if it cannot be retrieved (or is missing from the
        cache), cannot be classified, or already scores below
        ``malicious_threshold``.
        """
        if self.cache:
            if self.sha256 not in self.samples:
                return False  # fetching it failed when the cache was built
            self.bytez = self.samples[self.sha256]
        else:
            try:
                self.bytez = interface.fetch_file(self.sha256)
            except interface.FileRetrievalFailure:
                print("failed fetching file")
                return False

        try:
            self.original_score = score_function(self.bytez)
        except interface.ClassificationFailure:
            print("Failed to classify file")
            return False

        if self.original_score < malicious_threshold:
            # Already benign; the agent would learn nothing from it.
            return False
        return True

    def _reset(self):
        """Start a new episode on a usable sample and return its features.

        Raises:
            ValueError: if ``sha256list`` is empty.
            RuntimeError: if every distinct sha256 was tried and rejected
                during this call. Previously this case looped forever.
        """
        self.turns = 0
        distinct = set(self.available_sha256)
        if not distinct:
            raise ValueError("sha256list is empty")

        rejected = set()
        while True:
            self.sha256 = self._next_sha256()
            if self._load_sample():
                break  # we're done here
            rejected.add(self.sha256)
            if len(rejected) >= len(distinct):
                raise RuntimeError(
                    "no usable malicious sample found in sha256list")

        # Only record history for samples that are actually played.
        self.history[self.sha256] = {'actions': [], 'evaded': False}
        print("new sha256: {}".format(self.sha256))

        self.observation_space = self.feature_extractor.extract(self.bytez)
        print("original score: {}".format(self.original_score))

        return np.asarray(self.observation_space)

    def _render(self, mode='human', close=False):
        """Rendering is not supported; this is a no-op."""
        pass
